{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "from target_assign_rl import IQLAgent, ReplayBuffer, raw_env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_iql(\n",
    "    env,\n",
    "    num_episodes,\n",
    "    batch_size=256,\n",
    "    update_target_every=500,\n",
    "    save_every=2000,\n",
    "    buffer_capacity=100000,\n",
    "    warmup_batches=4,\n",
    "    data_path=\"training_data.npy\",\n",
    "):\n",
    "    \"\"\"Train one shared IQLAgent on `env` and collect per-episode metrics.\n",
    "\n",
    "    In every episode each entry of `env.agents` selects one masked action and\n",
    "    steps the environment once. All transitions of the episode are pushed to\n",
    "    the replay buffer with the reward returned by `env.last()` after the final\n",
    "    step; only the last transition carries the `done` flag from that call.\n",
    "\n",
    "    Args:\n",
    "        env: Environment providing `reset`, `state`, `agents`, `action_space`,\n",
    "            `action_mask`, `step` and `last`.\n",
    "        num_episodes: Number of episodes to run.\n",
    "        batch_size: Number of transitions sampled for each agent update.\n",
    "        update_target_every: Episode interval between target-network updates.\n",
    "        save_every: Episode interval for printing running averages; a\n",
    "            checkpoint is saved when the average reward of the interval beats\n",
    "            the best average seen so far.\n",
    "        buffer_capacity: Capacity passed to `ReplayBuffer`.\n",
    "        warmup_batches: Updates begin once the buffer holds more than\n",
    "            `batch_size * warmup_batches` transitions.\n",
    "        data_path: File the metrics dict is written to with `np.save`.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(rl_agent, training_data)`; `training_data` maps the keys\n",
    "        \"rewards\", \"losses\", \"kd_ratio\", \"threat_destroyed\", \"threat_remain\"\n",
    "        and \"drone_lost\" to per-episode lists.\n",
    "    \"\"\"\n",
    "    # Reset once up front so the state and action dimensions can be queried.\n",
    "    env.reset()\n",
    "    state_dim = env.state().shape[0]\n",
    "    action_dim = env.action_space(env.agents[0]).n\n",
    "\n",
    "    rl_agent = IQLAgent(state_dim, action_dim)\n",
    "    replay_buffer = ReplayBuffer(buffer_capacity)\n",
    "    warmup_size = batch_size * warmup_batches\n",
    "\n",
    "    best_avg_reward = -np.inf\n",
    "    episode_rewards = []\n",
    "    episode_losses = []\n",
    "    episode_kd = []\n",
    "    episode_kill = []\n",
    "    episode_remain = []\n",
    "    episode_lost = []\n",
    "\n",
    "    for episode in range(num_episodes):\n",
    "        env.reset()\n",
    "        episode_loss = 0\n",
    "\n",
    "        # Collect (state, action, next_state) for every agent that acted.\n",
    "        transitions = []\n",
    "        for agent in env.agents:\n",
    "            agent_state = env.state()\n",
    "            action_mask = env.action_mask(agent)\n",
    "            action = rl_agent.select_action(agent_state, action_mask)\n",
    "            env.step(action)\n",
    "            transitions.append((agent_state, action, env.state()))\n",
    "\n",
    "        _, reward, done, _, info = env.last()\n",
    "\n",
    "        # Iterate over the transitions that were actually collected instead of\n",
    "        # re-reading len(env.agents): the agent list may differ after stepping,\n",
    "        # which would silently drop transitions. Only the final transition of\n",
    "        # the episode receives the done flag.\n",
    "        last_index = len(transitions) - 1\n",
    "        for index, (state, action, next_state) in enumerate(transitions):\n",
    "            is_done = done if index == last_index else False\n",
    "            replay_buffer.push(state, action, reward, next_state, is_done)\n",
    "        episode_reward = reward\n",
    "\n",
    "        # One gradient update per episode once the buffer is warmed up;\n",
    "        # epsilon is only decayed when an update happened.\n",
    "        if len(replay_buffer) > warmup_size:\n",
    "            batch = replay_buffer.sample(batch_size)\n",
    "            episode_loss += rl_agent.update(batch)\n",
    "            rl_agent.update_epsilon()\n",
    "\n",
    "        if episode % update_target_every == 0:\n",
    "            rl_agent.update_target_network()\n",
    "\n",
    "        episode_rewards.append(episode_reward)\n",
    "        episode_losses.append(episode_loss)\n",
    "        episode_kd.append(info[\"kd_ratio\"])\n",
    "        episode_kill.append(info[\"threat_destroyed\"])\n",
    "        episode_remain.append(info[\"num_remaining_threat\"])\n",
    "        episode_lost.append(info[\"drone_lost\"])\n",
    "\n",
    "        if (episode + 1) % save_every == 0:\n",
    "            avg_reward = np.mean(episode_rewards[-save_every:])\n",
    "            avg_loss = np.mean(episode_losses[-save_every:])\n",
    "            print(\n",
    "                f\"Episode {episode + 1}, Avg Reward: {avg_reward:.2f}, Avg Loss: {avg_loss:.4f}, Epsilon: {rl_agent.epsilon:.2f}\"\n",
    "            )\n",
    "            # Checkpoint only when the interval average improved.\n",
    "            if avg_reward > best_avg_reward:\n",
    "                best_avg_reward = avg_reward\n",
    "                rl_agent.save_checkpoint(episode + 1)\n",
    "\n",
    "    training_data = {\n",
    "        \"rewards\": episode_rewards,\n",
    "        \"losses\": episode_losses,\n",
    "        \"kd_ratio\": episode_kd,\n",
    "        \"threat_destroyed\": episode_kill,\n",
    "        \"threat_remain\": episode_remain,\n",
    "        \"drone_lost\": episode_lost,\n",
    "    }\n",
    "\n",
    "    # np.save stores the dict as a pickled object array; reading it back needs\n",
    "    # np.load(..., allow_pickle=True).item().\n",
    "    np.save(data_path, training_data)\n",
    "    rl_agent.save_checkpoint(num_episodes)\n",
    "\n",
    "    return rl_agent, training_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the environment from a plain config dict, then launch training.\n",
    "env = raw_env(\n",
    "    {\n",
    "        \"min_drones\": 20,\n",
    "        \"possible_level\": [0, 0.05, 0.1, 0.5, 0.8],\n",
    "        \"threat_dist\": [0.1, 0.3, 0.1, 0.35, 0.15],\n",
    "        \"attack_prob\": 0.6,\n",
    "    }\n",
    ")\n",
    "\n",
    "# One million training episodes.\n",
    "trained_agent, training_data = train_iql(env, num_episodes=1_000_000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_moving_average(df, column, window=1000, figsize=(12, 6)):\n",
    "    \"\"\"Plot one metric per episode together with its rolling mean and save it.\n",
    "\n",
    "    The raw series is drawn semi-transparently with the `window`-episode\n",
    "    rolling mean on top. The figure is written to\n",
    "    \"<column>_moving_average.png\" and then closed.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=figsize)\n",
    "\n",
    "    episodes = df[\"episode\"]\n",
    "    raw_values = df[column]\n",
    "    ax.plot(episodes, raw_values, alpha=0.3, label=\"Raw\")\n",
    "\n",
    "    smoothed = raw_values.rolling(window=window).mean()\n",
    "    ax.plot(episodes, smoothed, label=f\"{window}-episode Moving Average\")\n",
    "\n",
    "    pretty_name = column.capitalize()\n",
    "    ax.set_title(f\"{pretty_name} over Episodes\")\n",
    "    ax.set_xlabel(\"Episode\")\n",
    "    ax.set_ylabel(pretty_name)\n",
    "    ax.legend()\n",
    "    ax.grid(True, alpha=0.3)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(f\"{column}_moving_average.png\")\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "def plot_multi_metric_comparison(df, metrics, figsize=(12, 6), normalize=True):\n",
    "    \"\"\"Plot several metrics against the episode axis on one figure and save it.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with an \"episode\" column and one column per metric.\n",
    "        metrics: Column names to plot; they are also joined with \"_\" to build\n",
    "            the output file name \"<metrics>_comparison.png\".\n",
    "        figsize: Figure size forwarded to matplotlib.\n",
    "        normalize: If True, min-max scale each metric before plotting. A metric\n",
    "            whose values are all equal is drawn as a flat line at 0 instead of\n",
    "            being divided by a zero range.\n",
    "    \"\"\"\n",
    "    value_kind = \"Normalized\" if normalize else \"Raw\"\n",
    "    fig, ax = plt.subplots(figsize=figsize)\n",
    "    try:\n",
    "        for metric in metrics:\n",
    "            series = df[metric]\n",
    "            if normalize:\n",
    "                low = series.min()\n",
    "                span = series.max() - low\n",
    "                # A constant metric has zero range; dividing by it would turn\n",
    "                # the whole series into NaN and the line would silently vanish.\n",
    "                series = (series - low) / span if span != 0 else series - low\n",
    "            ax.plot(df[\"episode\"], series, label=metric)\n",
    "\n",
    "        ax.set_title(f\"{value_kind} Metrics Comparison\")\n",
    "        ax.set_xlabel(\"Episode\")\n",
    "        ax.set_ylabel(f\"{value_kind} Value\")\n",
    "        ax.legend()\n",
    "        ax.grid(True, alpha=0.3)\n",
    "        fig.tight_layout()\n",
    "\n",
    "        metrics_str = \"_\".join(metrics)\n",
    "        fig.savefig(f\"{metrics_str}_comparison.png\")\n",
    "    finally:\n",
    "        # Release the figure even if plotting or saving failed.\n",
    "        plt.close(fig)\n",
    "\n",
    "\n",
    "def plot_correlation_heatmap(df, figsize=(10, 8)):\n",
    "    \"\"\"Draw the pairwise correlation matrix of `df` as an annotated heatmap.\n",
    "\n",
    "    The figure is saved to \"correlation_heatmap.png\" and then closed.\n",
    "    \"\"\"\n",
    "    correlations = df.corr()\n",
    "\n",
    "    # Colour scale fixed to [-1, 1] and centred on 0.\n",
    "    heatmap_options = {\"annot\": True, \"cmap\": \"coolwarm\", \"vmin\": -1, \"vmax\": 1, \"center\": 0}\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=figsize)\n",
    "    sns.heatmap(correlations, ax=ax, **heatmap_options)\n",
    "    ax.set_title(\"Correlation Heatmap of Training Metrics\")\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(\"correlation_heatmap.png\")\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "def analyze_training_data(data, normalize=True):\n",
    "    \"\"\"Turn the training metrics into a DataFrame and save every analysis plot.\n",
    "\n",
    "    Adds a 1-based \"episode\" column, then writes the moving-average plots for\n",
    "    rewards, losses and kd_ratio, a combined comparison plot of those three\n",
    "    metrics, and a correlation heatmap. Returns the DataFrame.\n",
    "    \"\"\"\n",
    "    df = pd.DataFrame(data)\n",
    "    episode_numbers = range(1, len(df) + 1)\n",
    "    df[\"episode\"] = episode_numbers\n",
    "\n",
    "    core_metrics = [\"rewards\", \"losses\", \"kd_ratio\"]\n",
    "\n",
    "    # Moving-average plot for each core metric.\n",
    "    for metric in core_metrics:\n",
    "        plot_moving_average(df, metric)\n",
    "\n",
    "    # All core metrics together on one figure.\n",
    "    plot_multi_metric_comparison(df, core_metrics, normalize=normalize)\n",
    "\n",
    "    # Correlation heatmap across every column.\n",
    "    plot_correlation_heatmap(df)\n",
    "\n",
    "    print(\"分析完成，所有图表已保存。\")\n",
    "\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To analyse a finished run without retraining, load the saved metrics instead:\n",
    "# training_data = np.load(\"training_data.npy\", allow_pickle=True).item()\n",
    "\n",
    "df = analyze_training_data(training_data)\n",
    "\n",
    "# Give each consecutive run of 1000 rows the same integer label, then average per label.\n",
    "df_mean = df.groupby(np.floor_divide(np.arange(len(df)), 1000)).mean()\n",
    "df_mean.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the averaged threat and drone counts on their raw scale.\n",
    "plot_multi_metric_comparison(\n",
    "    df_mean,\n",
    "    [\"threat_destroyed\", \"threat_remain\", \"drone_lost\"],\n",
    "    normalize=False,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "driving_gym",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
